import gym
from gym import spaces
import numpy as np
import pdb
from .multiagentenv import MultiAgentEnv

class LavaPath(gym.Env):
    """Two-agent grid world in which the agents must swap lanes between two lava pools.

    The world is a 12x12 grid indexed ``world[x, y]``. Its outer ring is
    padding, and movement is masked so agents stay on coordinates 1..10 along
    both axes. Cell codes:

        padding: -1
        road: 0
        lava: 1
        goal for agent_1: 2
        goal for agent_2: 3

    After ``reset()`` agent 1 starts at (6, 2) with its goal at (5, 10), and
    agent 2 starts at (5, 2) with its goal at (6, 10).
    """
    metadata = {'render.modes': ['human']}

    def __init__(self):
        super(LavaPath, self).__init__()
        self.action_space = spaces.Discrete(5)  # action space: up, down, left, right, stay
        # Partial observation per agent: the 3x3 neighbourhood (cell codes in
        # [-1, 3]) followed by the agent's own (x, y) shifted by -1 (in [0, 9]).
        obs_low = np.array([[-1.] * 9 + [0., 0.]], dtype=np.float32)
        obs_high = np.array([[3.] * 9 + [9., 9.]], dtype=np.float32)
        self.observation_space = spaces.Box(low=obs_low, high=obs_high, shape=(1, 11), dtype=np.float32)
        # [agent1 position, agent2 position]; overwritten by reset() below
        self.positions = np.array([[5, 2], [6, 3]], dtype=np.uint8)
        self.world = np.zeros((12, 12), dtype=np.float32)
        self.done = False
        self.t = 0
        self.length = 60  # number of steps before the episode times out
        self.agents = [0, 1]
        self.dones_agent = [False, False]  # per-agent "standing on own goal" flags

        self.reset()

    def step(self, actions):
        """Advance the environment by one time step.

        Args:
            actions: pair ``(agent1_action, agent2_action)`` using the codes
                up: 0, down: 1, left: 2, right: 3, stay: 4.

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` is the
            list of per-agent observations from ``get_observation`` and
            ``reward`` is the single shared team reward.
        """
        self.t += 1
        self.get_next_position(actions)
        self.if_done()
        reward = self.get_reward()
        observation = self.get_observation()
        info = self.get_info()
        return observation, reward, self.done, info

    def reset(self):
        """Rebuild the map, restore the start positions and return the observations."""
        self.done = False
        self.dones_agent = [False, False]
        self.t = 0
        # Reset the state of the environment to an initial state
        self.world = np.zeros((12, 12), dtype=np.float32)
        # Set padding area
        self.world[0, :] = -1.
        self.world[:, 0] = -1.
        self.world[-1, :] = -1.
        self.world[:, -1] = -1.
        # Set lava
        self.world[2:5, 3:9] = 1.
        self.world[7:10, 3:9] = 1.
        # Set goal
        self.world[5, -2] = 2.
        self.world[6, -2] = 3.
        # Set starting point: the agents start in each other's goal lane
        self.positions = np.array([[6, 2], [5, 2]], dtype=np.uint8)  # [agent1 position, agent2 position]
        observation = self.get_observation()
        return observation

    def _coords(self):
        """Return ``(x1, y1, x2, y2)`` as Python ints.

        ``self.positions`` is uint8, so arithmetic on its elements wraps around
        (e.g. ``2 - 10 -> 248``) instead of going negative; converting first
        keeps differences exact.
        """
        (x1, y1), (x2, y2) = self.positions.tolist()
        return x1, y1, x2, y2

    @staticmethod
    def _in_lava(x, y):
        """Return True if cell ``(x, y)`` lies in one of the lava pools laid out in reset()."""
        return 3 <= y <= 8 and (2 <= x <= 4 or 7 <= x <= 9)

    def get_observation(self):
        """Return one ``(1, 11)`` float32 observation per agent.

        Each observation is the flattened 3x3 window of ``world`` centred on
        the agent, followed by the agent's coordinates minus one.
        """
        observation = []
        for coor in self.positions:
            x, y = coor
            window = self.world[x - 1:x + 2, y - 1:y + 2].reshape(-1)
            obs = np.concatenate([window, coor - 1])
            observation.append(np.expand_dims(obs, axis=0).astype(np.float32))
        return observation

    def get_state(self):
        """Return the full state: the flattened 12x12 world plus both agents' coordinates minus one."""
        coor1, coor2 = self.positions
        state = np.concatenate([self.world.reshape(-1), coor1 - 1, coor2 - 1])
        return state.astype(np.float32)

    def get_reward(self):
        """Return the shared team reward for the current positions.

        Each agent standing in lava costs -10. If both agents are on their
        goals the team gets +40; otherwise, once the episode has timed out, the
        team gets 40 minus both agents' Manhattan distances to their goals.
        """
        x1, y1, x2, y2 = self._coords()
        # lava
        lava = 0.
        for x, y in ((x1, y1), (x2, y2)):
            if self._in_lava(x, y):
                lava -= 10.
        # goal
        goal = 0.
        if x1 == 5 and y1 == 10 and x2 == 6 and y2 == 10:
            goal += 40.
        elif self.t >= self.length:
            # Plain ints here: with uint8 coordinates these differences wrapped
            # around and produced huge bogus distances.
            distance1 = abs(x1 - 5) + abs(y1 - 10)
            distance2 = abs(x2 - 6) + abs(y2 - 10)
            goal = 40. - distance1 - distance2
        # total reward
        return lava + goal

    def get_next_position(self, actions):  # Environmental Dynamic
        """Move both agents according to ``actions``, updating ``self.positions`` in place.

        A move is applied only if the matching entry of
        ``get_avail_agent_actions`` (evaluated on the pre-move positions) is
        True. Interaction rules:

        * adjacent in the same row/column and moving into each other
          ("smash"): both bounce back one cell;
        * adjacent and only one moving into the other: that agent pushes both
          one cell, provided both can move in that direction;
        * adjacent and moving away from each other: the move is free;
        * two cells apart in a row/column and moving toward each other, or
          diagonal neighbours whose moves target the same cell ("conflict"):
          nobody moves.

        Args:
            actions: pair ``(agent1_action, agent2_action)`` with codes
                0=up (y+1), 1=down (y-1), 2=left (x-1), 3=right (x+1), 4=stay.
        """
        agent1_act = actions[0]
        avail_act1 = self.get_avail_agent_actions(0)
        agent2_act = actions[1]
        avail_act2 = self.get_avail_agent_actions(1)

        # Pre-move coordinates as Python ints, so expressions such as
        # ``x2 - 2`` cannot wrap around like uint8 values do.
        x1, y1, x2, y2 = self._coords()

        # Collision dynamic
        # x axis: both agents in the same row
        if y1 == y2:
            if x1 == x2 - 1:  # agent1 is on the left
                if agent1_act == 3 and agent2_act == 2:  # smash together
                    if avail_act1[2]:
                        self.positions[0][0] -= 1
                    if avail_act2[3]:
                        self.positions[1][0] += 1
                elif agent1_act == 3:  # agent1 pushes agent2
                    if avail_act1[3] and avail_act2[3]:
                        self.positions[0][0] += 1
                        self.positions[1][0] += 1
                elif agent2_act == 2:  # agent2 pushes agent1
                    if avail_act1[2] and avail_act2[2]:
                        self.positions[0][0] -= 1
                        self.positions[1][0] -= 1
                else:  # nobody moves toward the other: stepping away is free
                    if agent1_act == 2 and avail_act1[2]:
                        self.positions[0][0] -= 1
                    if agent2_act == 3 and avail_act2[3]:
                        self.positions[1][0] += 1
            elif x2 == x1 - 1:  # agent1 is on the right
                if agent1_act == 2 and agent2_act == 3:  # smash together
                    if avail_act1[3]:
                        self.positions[0][0] += 1
                    if avail_act2[2]:
                        self.positions[1][0] -= 1
                elif agent1_act == 2:  # agent1 pushes agent2
                    if avail_act1[2] and avail_act2[2]:
                        self.positions[1][0] -= 1
                        self.positions[0][0] -= 1
                elif agent2_act == 3:  # agent2 pushes agent1
                    if avail_act1[3] and avail_act2[3]:
                        self.positions[0][0] += 1
                        self.positions[1][0] += 1
                else:  # nobody moves toward the other: stepping away is free
                    if agent1_act == 3 and avail_act1[3]:
                        self.positions[0][0] += 1
                    if agent2_act == 2 and avail_act2[2]:
                        self.positions[1][0] -= 1
            elif x2 == x1 - 2 and agent2_act == 3 and agent1_act == 2:  # conflict: both target x1 - 1
                pass
            elif x1 == x2 - 2 and agent1_act == 3 and agent2_act == 2:  # conflict: both target x1 + 1
                pass
            else:  # no contact
                if agent1_act == 3 and avail_act1[3]:
                    self.positions[0][0] += 1
                elif agent1_act == 2 and avail_act1[2]:
                    self.positions[0][0] -= 1
                if agent2_act == 3 and avail_act2[3]:
                    self.positions[1][0] += 1
                elif agent2_act == 2 and avail_act2[2]:
                    self.positions[1][0] -= 1
            # y axis movement
            if agent1_act == 0 and avail_act1[0]:
                self.positions[0][1] += 1
            elif agent1_act == 1 and avail_act1[1]:
                self.positions[0][1] -= 1
            if agent2_act == 0 and avail_act2[0]:
                self.positions[1][1] += 1
            elif agent2_act == 1 and avail_act2[1]:
                self.positions[1][1] -= 1
        # y axis: both agents in the same column
        elif x1 == x2:
            if y1 == y2 - 1:  # agent1 is below
                if agent1_act == 0 and agent2_act == 1:  # smash together
                    if avail_act1[1]:
                        self.positions[0][1] -= 1
                    if avail_act2[0]:
                        self.positions[1][1] += 1
                elif agent1_act == 0:  # agent1 pushes agent2
                    if avail_act1[0] and avail_act2[0]:
                        self.positions[0][1] += 1
                        self.positions[1][1] += 1
                elif agent2_act == 1:  # agent2 pushes agent1
                    if avail_act1[1] and avail_act2[1]:
                        self.positions[0][1] -= 1
                        self.positions[1][1] -= 1
                else:  # nobody moves toward the other: stepping away is free
                    if agent1_act == 1 and avail_act1[1]:
                        self.positions[0][1] -= 1
                    if agent2_act == 0 and avail_act2[0]:
                        self.positions[1][1] += 1
            elif y2 == y1 - 1:  # agent1 is above
                if agent1_act == 1 and agent2_act == 0:  # smash together
                    if avail_act1[0]:
                        self.positions[0][1] += 1
                    if avail_act2[1]:
                        self.positions[1][1] -= 1
                elif agent1_act == 1:  # agent1 pushes agent2
                    if avail_act1[1] and avail_act2[1]:
                        self.positions[0][1] -= 1
                        self.positions[1][1] -= 1
                elif agent2_act == 0:  # agent2 pushes agent1
                    if avail_act1[0] and avail_act2[0]:
                        self.positions[0][1] += 1
                        self.positions[1][1] += 1
                else:  # nobody moves toward the other: stepping away is free
                    if agent1_act == 0 and avail_act1[0]:
                        self.positions[0][1] += 1
                    if agent2_act == 1 and avail_act2[1]:
                        self.positions[1][1] -= 1
            elif y2 == y1 - 2 and agent2_act == 0 and agent1_act == 1:  # conflict: both target y1 - 1
                pass
            elif y1 == y2 - 2 and agent1_act == 0 and agent2_act == 1:  # conflict: both target y1 + 1
                pass
            else:  # no contact
                if agent1_act == 0 and avail_act1[0]:
                    self.positions[0][1] += 1
                elif agent1_act == 1 and avail_act1[1]:
                    self.positions[0][1] -= 1
                if agent2_act == 0 and avail_act2[0]:
                    self.positions[1][1] += 1
                elif agent2_act == 1 and avail_act2[1]:
                    self.positions[1][1] -= 1
            # x axis movement
            if agent1_act == 2 and avail_act1[2]:
                self.positions[0][0] -= 1
            elif agent1_act == 3 and avail_act1[3]:
                self.positions[0][0] += 1
            if agent2_act == 2 and avail_act2[2]:
                self.positions[1][0] -= 1
            elif agent2_act == 3 and avail_act2[3]:
                self.positions[1][0] += 1
        else:  # not in the same row or column
            # conflict: diagonal neighbours whose moves would land on the same cell
            if (x1 == x2 - 1 and y1 == y2 - 1) and \
                    ((agent1_act == 0 and agent2_act == 2) or (agent1_act == 3 and agent2_act == 1)):
                pass
            elif (x1 == x2 + 1 and y1 == y2 + 1) and \
                    ((agent1_act == 1 and agent2_act == 3) or (agent1_act == 2 and agent2_act == 0)):
                pass
            elif (x1 == x2 - 1 and y1 == y2 + 1) and \
                    ((agent1_act == 3 and agent2_act == 0) or (agent1_act == 1 and agent2_act == 2)):
                pass
            elif (x1 == x2 + 1 and y1 == y2 - 1) and \
                    ((agent1_act == 0 and agent2_act == 3) or (agent1_act == 2 and agent2_act == 1)):
                pass
            else:  # no conflict
                # agent1
                if agent1_act == 0 and avail_act1[0]:
                    self.positions[0][1] += 1
                elif agent1_act == 1 and avail_act1[1]:
                    self.positions[0][1] -= 1
                elif agent1_act == 2 and avail_act1[2]:
                    self.positions[0][0] -= 1
                elif agent1_act == 3 and avail_act1[3]:
                    self.positions[0][0] += 1
                # agent2
                if agent2_act == 0 and avail_act2[0]:
                    self.positions[1][1] += 1
                elif agent2_act == 1 and avail_act2[1]:
                    self.positions[1][1] -= 1
                elif agent2_act == 2 and avail_act2[2]:
                    self.positions[1][0] -= 1
                elif agent2_act == 3 and avail_act2[3]:
                    self.positions[1][0] += 1

    def if_done(self):
        """Update ``self.done`` and the per-agent ``self.dones_agent`` flags.

        The episode ends on leaving the playable area, entering lava, both
        agents sharing a cell, both agents reaching their goals, or timeout.
        ``dones_agent[i]`` is True while agent ``i`` stands on its own goal.
        """
        x1, y1, x2, y2 = self._coords()
        # if in padding area
        if x1 <= 0 or x1 >= 11 or x2 <= 0 or x2 >= 11:
            self.done = True
        if y1 <= 0 or y1 >= 11 or y2 <= 0 or y2 >= 11:
            self.done = True
        # if in lava
        if self._in_lava(x1, y1) or self._in_lava(x2, y2):
            self.done = True
        # conflict
        if x1 == x2 and y1 == y2:
            self.done = True

        # reach the goal
        if x1 == 5 and y1 == 10 and x2 == 6 and y2 == 10:
            self.done = True

        self.dones_agent[0] = (x1 == 5 and y1 == 10)
        self.dones_agent[1] = (x2 == 6 and y2 == 10)

        # timeout
        if self.t >= self.length:
            self.done = True

    def get_avail_agent_actions(self, agent_id):
        """Return a length-5 boolean mask of the actions ``agent_id`` may take.

        Movement actions are masked at the edge of the playable 1..10 area.
        "stay" (4) is only available while the agent is flagged in
        ``dones_agent`` (i.e. standing on its goal).
        """
        avail_actions = np.array([False, False, False, False, False])
        coor = self.positions[agent_id]
        x = coor[0]
        y = coor[1]
        if x > 1:
            avail_actions[2] = True
        if x < 10:
            avail_actions[3] = True
        if y > 1:
            avail_actions[1] = True
        if y < 10:
            avail_actions[0] = True

        avail_actions[4] = self.dones_agent[agent_id]

        return avail_actions

    def get_info(self):
        """Return the (always empty) per-step info dict."""
        return {}

    def render(self, mode='human', close=False):
        # Render the environment to the screen (not implemented; no-op)

        return

    def base_img(self):
        return

    def draw_grid(self):
        return

    def draw_agent(self):
        return

class LavaPathWrapper(MultiAgentEnv):
    """``MultiAgentEnv`` adapter exposing :class:`LavaPath` through the query API.

    The wrapper caches the latest per-agent observations returned by the
    wrapped env in ``self.obs`` and answers the ``get_*`` queries from them.
    """

    def __init__(self, seed=0):
        self.env = LavaPath()
        self.obs = self.env.reset()
        self.l_seed = seed  # stored only; seed() does not use it
        self.episode_limit = self.env.length

    def step(self, actions):
        """ Returns reward, terminated, info """

        self.obs, rewards, done, info = self.env.step(actions)
        return rewards, done, info

    def get_obs(self):
        """ Returns all agent observations as an array with one row per agent """
        return np.array([self.get_obs_agent(a) for a in self.env.agents])

    def get_obs_agent(self, agent_id):
        """ Returns observation for agent_id as a flat list of floats """
        o = self.obs[agent_id]

        # the env emits (1, obs_dim) arrays; drop the leading axis
        return o.tolist()[0]

    def get_obs_size(self):
        """ Returns the shape of the observation """
        return len(self.get_obs_agent(0))

    def get_state(self):
        """ Returns the global state: all agents' observations concatenated into one flat list """
        state = []
        for o in self.get_obs():
            state.extend(o)
        return state

    def get_state_size(self):
        """ Returns the shape of the state"""

        return len(self.get_state())

    def get_avail_actions(self):
        """ Returns the available-action masks (lists of 0/1) of all agents """
        return [self.get_avail_agent_actions(a) for a in self.env.agents]

    def get_avail_agent_actions(self, agent_id):
        """ Returns the available actions for agent_id """
        avail_actions = self.env.get_avail_agent_actions(agent_id)
        return [1 if a else 0 for a in avail_actions]

    def get_total_actions(self):
        """ Returns the total number of actions an agent could ever take """
        # Only suitable for a discrete 1-dimensional action space per agent.
        return int(self.env.action_space.n)

    def reset(self):
        """ Returns initial observations and states"""
        self.obs = self.env.reset()
        return self.get_obs(), self.get_state()

    def render(self):
        raise NotImplementedError

    def close(self):
        # Nothing to release.
        pass

    def seed(self):
        # Seeding is not supported by the wrapped env; deliberately a no-op.
        pass

    def save_replay(self):
        raise NotImplementedError

    def get_env_info(self):
        """ Returns the shape/size summary of the environment """
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": len(self.env.agents),
                    "episode_limit": self.episode_limit}

        return env_info

    def get_stats(self):
        # No statistics are collected; returns None.
        pass